import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import os
from blind_deconv import blind_deconv
from ringing_artifacts_removal import ringing_artifacts_removal
from misc import visualize_rgb ,visualize_image, gray_image, process_image,PSNR
from metrics import psnr
import time
import sys


def _latent_to_uint8(latent):
    """Convert the restored-image tensor into a uint8 array for ``Image.fromarray``.

    Values are clipped to [0, 1], singleton dimensions are squeezed away, and
    the result is scaled by 255 and truncated (not rounded) to uint8.
    """
    # torch.clamp propagates NaN just like the original boolean-mask clipping.
    clipped = torch.clamp(latent, 0.0, 1.0)
    # detach/cpu make .numpy() safe even if the tensor tracks grad or is on a GPU.
    return (clipped.squeeze() * 255.0).detach().cpu().numpy().astype('uint8')


def _kernel_to_uint8(kernel):
    """Min-max normalise a kernel tensor to 0..255 and return it as a uint8 array.

    A kernel with zero (or NaN) dynamic range is mapped to all zeros instead of
    dividing by zero, which would produce NaNs and an undefined uint8 cast.
    """
    kmn, kmx = kernel.min(), kernel.max()
    span = kmx - kmn
    if span > 0:
        scaled = (kernel - kmn) / span * 255.0
    else:
        scaled = torch.zeros_like(kernel)
    return scaled.detach().cpu().numpy().astype('uint8')


def main():
    """Run blind deconvolution over ``images/Lai/saturated_*`` for several kernel sizes.

    For every entry of ``kernel_list`` and every input
    ``images/Lai/saturated_0{i}_kernel_0{k}.png`` (i in 1..5, k in 1..4):

    1. ``blind_deconv`` is run on the ``gray_image`` version of the input and
       its returned kernel is kept (the second return value is unused);
    2. ``ringing_artifacts_removal`` is applied to the ``process_image``
       version of the input (permuted and scaled by 1/255) with that kernel;
    3. the result is written to ``results/<stem>_<size_idx>.png`` and a
       min-max-normalised view of the kernel to
       ``results/<stem>_kernel_<size_idx>.png``.

    The wall-clock time of each ``blind_deconv`` call is printed.
    """
    # NOTE(review): 87 breaks the otherwise step-4 progression (77 would be
    # next) — confirm this value is intended.
    kernel_list = [41, 45, 49, 53, 57, 61, 65, 69, 73, 87]
    prefix = 'saturated'

    # Weights passed to blind_deconv / ringing_artifacts_removal; identical
    # for every run, so they are set once outside the loops.
    lambda_dark = 4e-3
    lambda_ftr = 3.5e-4
    lambda_grad = 5e-3
    lambda_tv = 1e-3
    lambda_l0 = 2e-4
    weight_ring = 0

    results_dir = 'results'
    os.makedirs(results_dir, exist_ok=True)

    for size_idx, kernel_size in enumerate(kernel_list):
        for img_idx in range(5):
            for ker_idx in range(4):
                stem = f'{prefix}_0{img_idx+1}_kernel_0{ker_idx+1}'
                image_path = f'images/Lai/{stem}.png'

                # A fresh dict per call, in case blind_deconv mutates its opts.
                opts = {
                    'prescale': 1,   # Downsampling
                    'xk_iter': 5,    # Iterations
                    'gamma_correct': 1.0,
                    'k_thresh': 20,
                    'kernel_size': kernel_size,
                }

                yg = gray_image(Image.open(image_path))
                start_time = time.perf_counter()
                kernel, _interim_latent = blind_deconv(yg, lambda_ftr, lambda_dark, lambda_grad, opts)
                elapsed = time.perf_counter() - start_time
                print(f"Time taken: {elapsed} seconds")

                y = process_image(Image.open(image_path))
                y = y.permute(1, 2, 0) / 255.0
                latent = ringing_artifacts_removal(y, kernel, lambda_tv, lambda_l0, weight_ring)

                Image.fromarray(_latent_to_uint8(latent)).save(
                    os.path.join(results_dir, f'{stem}_{size_idx}.png'))
                Image.fromarray(_kernel_to_uint8(kernel)).save(
                    os.path.join(results_dir, f'{stem}_kernel_{size_idx}.png'))


if __name__ == '__main__':
    # Run the sweep only when executed as a script, not when imported.
    main()
